from collections import namedtuple
from typing import Callable, Optional

import lightning.pytorch as pl
import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer


# A dataloader batch. ``inputs`` is unpacked into the model call and
# ``labels`` is unpacked into the loss function call by SSLPretrainModule.
Batch = namedtuple("Batch", "inputs labels")


class SSLPretrainModule(pl.LightningModule):
    """LightningModule wrapper for self-supervised pretraining.

    The wrapped ``model`` is called as ``model(*batch.inputs)``. Its output
    (which must be iterable) is unpacked into
    ``loss_fn(*output, *batch.labels)``, which must return a
    ``(loss, num_frame)`` pair.

    Args:
        model (nn.Module): Network to pretrain.
        loss_fn (Callable): Callable returning ``(loss, num_frame)``.
        optimizer (Optimizer): Already-constructed optimizer, returned
            from ``configure_optimizers``.
        lr_scheduler (Optional[_LRScheduler]): Optional scheduler, returned
            alongside ``optimizer`` from ``configure_optimizers``.
    """

    def __init__(
        self,
        model: nn.Module,
        loss_fn: Callable,
        optimizer: Optimizer,
        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    ):
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def configure_optimizers(self):
        """Return the optimizer (and scheduler, if any) given to ``__init__``.

        Without this hook the stored optimizer/scheduler would never reach
        the Trainer. Subclasses may override it for custom setups.
        """
        if self.lr_scheduler is None:
            return self.optimizer
        return {"optimizer": self.optimizer, "lr_scheduler": self.lr_scheduler}

    def log_metrics(self, batch: Batch, output, loss, step_type: str):
        """Hook for logging per-step information; does nothing by default.

        Users are expected to override this to log information such as loss
        values, metric scores, etc. (for example via ``self.log``).

        Args:
            batch (Batch): Batch tuple from the dataloader.
            output: Output returned by ``self.model(*batch.inputs)``.
            loss (Tensor): Loss returned by ``loss_fn``. For training steps
                this is the value before frame normalization.
            step_type (str): Type of step. Choices are "train", "val", and "test".
        """
        pass

    def training_step(self, batch: Batch, batch_idx):
        """Run one training step and return the frame-normalized loss.

        The local loss is scaled by ``world_size / total_frames``, where
        ``total_frames`` is the sum of ``num_frame`` over all processes.
        Presumably ``loss_fn`` returns a frame-summed loss and the strategy
        averages gradients over ranks (e.g. DDP), so the effective gradient
        is the loss averaged over every frame in the global batch.
        """
        out = self.model(*batch.inputs)
        loss, num_frame = self.loss_fn(*out, *batch.labels)
        self.log_metrics(batch, out, loss, "train")

        # all_gather only gathers tensors (a plain int would pass through
        # unchanged), and collective backends need it on the module's device.
        num_frame = torch.as_tensor(num_frame, device=self.device)
        # Normalize the loss based on the sum of num_frame across all GPUs.
        total_frames = self.all_gather(num_frame).float().sum()
        self.log(
            "Gathered number of frames",
            total_frames,
            on_step=True,
            on_epoch=True,
        )
        # Take the world size from the trainer: on a single device all_gather
        # returns its input unchanged, so the gathered tensor's first dim is
        # not a reliable world size. Clamping the denominator (frame counts
        # are assumed to be integers) avoids inf/NaN for empty batches.
        world_size = self.trainer.world_size
        # Out-of-place multiply: do not mutate the tensor already passed to
        # log_metrics, nor one loss_fn may have saved for backward.
        loss = loss * (world_size / total_frames.clamp(min=1.0))

        return loss

    def validation_step(self, batch: Batch, batch_idx):
        """Run one validation step and return the unnormalized loss."""
        out = self.model(*batch.inputs)
        loss, _ = self.loss_fn(*out, *batch.labels)
        self.log_metrics(batch, out, loss, "val")
        return loss
